Historical Panchromatic Orthophoto Colorisation with a Generative Adversarial Network¶
By Kay Warrie
This is a deep learning model to colorize historical greyscale or panchromatic orthophotos and orthophoto mosaics. Greyscale images made using all the wavelengths of the visible spectrum are called panchromatic; most historical aerial images are panchromatic.
An orthophoto is an aerial photograph geometrically corrected ("orthorectified") such that the scale is uniform. It is the basis for most mapping solutions.
An orthophoto mosaic is a type of large-scale image created by stitching together a collection of orthophotos to produce a seamless, georeferenced image, for example the "Satellite" view in Google Maps, which, by the way, is mostly made from aerial photos and not from satellite images.
A Generative Adversarial Network (GAN) is a type of artificial intelligence, a generative model consisting of two neural networks, the generator and the discriminator.
The generator is a convolutional neural net that produces an image, and the discriminator is a model that tries to distinguish between the label data and the generated images.
The loss of the GAN's discriminator is calculated by passing it a batch of the generator's output and a batch of real data and seeing whether it can distinguish between the two. The loss of the generator is derived from the output of the discriminator.
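As a minimal pure-Python sketch (with made-up logit values, not this notebook's model): the discriminator loss is the average binary cross-entropy over a batch of real images labelled 1 and a batch of generated images labelled 0, which is what the GANLoss class later in this notebook computes with nn.BCEWithLogitsLoss:

```python
import math

def bce_with_logits(logit: float, label: int) -> float:
    # binary cross-entropy on a single raw logit, like nn.BCEWithLogitsLoss
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))

real_preds = [2.0, 1.5, 3.0]     # hypothetical discriminator logits on real images
fake_preds = [-2.0, -1.0, -3.0]  # hypothetical logits on generated images
loss_real = sum(bce_with_logits(x, 1) for x in real_preds) / len(real_preds)
loss_fake = sum(bce_with_logits(x, 0) for x in fake_preds) / len(fake_preds)
loss_D = 0.5 * (loss_real + loss_fake)
print(loss_D)  # small, because this discriminator is "right" on both batches
```

A confident and correct discriminator gives a low loss; a fooled discriminator gives a high one, and that same signal (with flipped labels) drives the generator.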
A full explanation of how this model was constructed can be found in explanation.ipynb.
To train and run the final model, a series of command-line tools was constructed:
- pretrain_unet.py -> initialise the U-net generator.
- trainWeigthed.py -> train the full GAN model, using the weighted dataset.
- inference.py -> test the resulting model on real greyscale images.
To showcase the inference results, an interactive webpage was constructed:
%%html
<iframe width="100%" height="500" style="border:none; overflow-x: hidden;"
src="https://warrieka.github.io/histo_ortho_viewer/?hidebanner=1" ></iframe>
Context and background:¶
There are many historical panchromatic aerial photos of Belgium, like those made by the national mapping agency NGI and by Allied and Axis aerial reconnaissance forces during WWII. Older private remote sensing operators, like the Belgian company Eurosense, also have large collections of historical data.
These are still in active use, for example for tracing the history of construction projects, tracking building violations, for historical and archeological research, or simply for communication and illustration purposes. They are mostly used as-is, without georeferencing, mosaicing or color optimisation. This is not ideal for interpretation purposes, as you cannot overlay other mapping data on these photos. Some efforts were made to create mosaics of these photos, like:
- The 1971 panchromatic orthomosaic of Flanders, made from a series of photos flown by Eurosense for the Flemish government.
- The 1955 orthomosaic of the city of Ghent, made by the city-archeology team and Team Data of the city of Ghent, derived from a "forgotten" photo collection found in the archives of the department of public works of the Flemish government.
- The 1940-1940 orthomosaic of Antwerp, derived from a heterogeneous collection of Allied and Axis aerial reconnaissance photos. This series was collected as source material for the book Vergeten Linies 3: Militair Erfgoed Binnen de Antwerpse Fortengordels Op Luchtfoto en Lidar by prof. Wouter Gheyle and Ignace Bourgeois, published by Provincie Antwerpen, 2018. The processing of the data was done by the city of Antwerp; a lot of processing was needed to match and improve colors, remove clouds and artefacts, and match resolutions.
Goals:¶
None of the previously mentioned mosaics was colorized; they are only available in grayscale.
The goal of this project is to create a tool to colorize these kinds of mosaics while preserving resolution and geographical metadata.
Most automated colorisation algorithms are based on convolutional neural networks (CNNs); the original approaches used a plain CNN, while later models added a Generative Adversarial Network (GAN) training phase.
Some older approaches include "Colorful Image Colorization" by Richard Zhang (2016). This was a regular classification CNN, no GAN just yet, so it could only handle a limited amount of features.
The most successful approach has been DeOldify by Jason Antic (2018). This is the model that is the basis for most commercial colorising software today.
Some newer models are based on image-to-image diffusion models, like ControlNet by Lvmin Zhang (2023), built on top of Stable Diffusion. But these have huge GPU demands and a big problem of hallucinating new information and destroying existing data, so they are not suited for our purposes.
All of these networks have been trained on "normal" photos and not on aerial imagery, and they tend to perform poorly when reconstructing an orthophoto. So we need to train our own network.
Similar projects¶
I did some further research and found some similar projects that translate one type of geospatial imagery into fake orthophotos.
- In this example from ESRI they generate a fake orthophoto from elevation data: "Generating rgb imagery from surface elevation using Pix2Pix"
- This person made fake aerial photos from 19th-century Ordnance Survey maps: "map2sat: Satellite Image Generation Conditioned on Maps, Generate Your Own Scotland" by Miguel Espinosa et al.
- Both articles are based on the original paper by Phillip Isola et al. (2016): "Image-to-Image Translation with Conditional Adversarial Networks".
In this next approach, ESRI instead uses a CycleGAN to translate SAR images (in the radio-wave part of the spectrum) into colorised photos. A CycleGAN is explicitly trained with unpaired data, forcing the model to find a broad "style" in the source data and transfer it to the target:
- "SAR to RGB image translation using CycleGAN"
- This article was based on "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks" by Jun-Yan Zhu (2017).
Final approach¶
While these approaches are valid, one of the advantages of greyscale imagery is that you already have part of the color: in the CIELab colorspace the L component is identical to the image in greyscale, so you only need to predict 2 values instead of 3, as you would when translating a map or a different part of the spectrum to an RGB image. The lightness (L), the value of a pixel in grey, is composed of the original RGB values according to this formula:
$$ L = 0.30 \times R + 0.59 \times G + 0.11 \times B $$
I found the article "Colorizing black & white images with U-Net and conditional GAN" by Moein Shariatnia, published on November 18, 2020 in Towards Data Science. In this article he outlines how to predict the ab values of a greyscale image in the CIELAB colorspace and then recompose them with the original image to create an RGB image. For the predictions he uses a variation of the U-net model, modified to output 2 channels. He mostly uses pytorch, torchvision and fast.ai for the implementation of his model; he also uses scikit-image for colorspace manipulations, as torchvision is lacking in this regard.
I largely copied his code, but I made several changes to fit it to training and inferring on older geospatial orthophotos, like a specific augmentation function and reading data with GDAL, a library that preserves geospatial metadata when reading data and offers several utilities to deal with geodata, unlike PIL, torchvision or OpenCV.
The remainder of the code in this notebook goes through all the steps of creating this model. These steps are a bit simplified compared to the full implementation, but they are fully functional and will create a model, just not a very good one. If you want to try this yourself you will need to download the source data as described in the chapter Datasources.
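The decompose/recompose flow described in that article can be sketched with scikit-image (a toy round-trip on a random image, not the notebook's actual data):

```python
import numpy as np
from skimage.color import rgb2lab, lab2rgb

rgb = np.random.rand(8, 8, 3)        # a tiny random RGB image, values in [0, 1]
lab = rgb2lab(rgb)                   # L in [0, 100], a and b roughly in [-110, 110]
L, ab = lab[..., :1], lab[..., 1:]   # keep L as the input, let the model predict ab
recomposed = lab2rgb(np.concatenate([L, ab], axis=-1))
print(np.abs(recomposed - rgb).max())  # round-trip error is tiny
```

If the model predicts ab perfectly, pasting the predictions back next to the original L channel reconstructs the color image exactly, which is why only 2 of the 3 channels need to be learned.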
Install the necessary libraries if needed¶
REMARK: check https://pytorch.org for the best option for your system. Also make a separate environment for your project.
To make an environment you can run the following in PowerShell:
python -m venv ortho_env
./ortho_env/Scripts/Activate.ps1
Then install the dependencies.
Alternatively you can also use Docker to isolate your code; a Dockerfile is provided.
%pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
%pip install fastai
%pip install GDAL
%pip install timm
%pip install scikit-image
%pip install tqdm ipywidgets
%pip install pillow
%pip install matplotlib
%pip install pandas
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from osgeo import gdal
gdal.UseExceptions() # needs to be called, so gdal will have readable exceptions
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb, rgb2gray
from skimage.io import imread
import time, datetime, os.path
from pathlib import Path
from typing import Iterable
from tqdm.notebook import tqdm
%matplotlib inline
import warnings
warnings.simplefilter("ignore", category=RuntimeWarning)
warnings.simplefilter('ignore', category=UserWarning)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEBUG = True
Preprocessing¶
Obtaining data¶
Datasources¶
- RGB-colored aerial photos for training (open data from the Flemish Government):
- 2023: Most recent images at the time of writing https://download.vlaanderen.be/product/10426-orthofotomoza%C3%AFek-middenschalig-winteropnamen-kleur-2023-vlaanderen
- 2015: The best resolution available https://download.vlaanderen.be/product/864-grootschalige_orthofotomozaieken
- 1979-1990: The oldest color images of Flanders https://download.vlaanderen.be/product/602-orthofotomoza%C3%AFek-kleinschalig-zomeropnamen-kleur-1979-1990-vlaanderen
- Potential sources of real panchromatic aerial photos for testing:
- NGI: http://www.cartesius.be/CartesiusPortal/
- Digitaal Vlaanderen: https://www.vlaanderen.be/datavindplaats/catalogus/orthofotomozaiek-kleinschalig-zomeropnamen-0
- City of Gent: https://stad.gent/nl/cultuur-vrije-tijd/cultuur/hoe-zag-jouw-buurt-eruit-de-jaren-50
- Panchromatic photos from the internal collections of the city of Antwerp https://felixarchief.antwerpen.be/archievenoverzicht/168417.
- Landuse: https://www.vlaanderen.be/datavindplaats/catalogus/bodembedekkingskaart-bbk-1m-resolutie-opname-2018
Tiling the source data from JPEG 2000 to JPEG¶
Large-scale mosaics are usually stored as JPEG 2000, a successor to the JPEG format that supports internal tiling and wavelet compression. torchvision can't read JPEG 2000, and the files are too large to process as a single array.
GDAL has a fast tool for this: gdal_retile (https://gdal.org/programs/gdal_retile.html). It preserves the geographical properties and metadata in a CSV file and can also write some overlap between tiles, which is necessary if you are going to apply convolutions to the data. Tiles containing only nodata are omitted.
For each JPEG 2000 file, write a collection of tiles and an index CSV file that describes them. For example, for the files of 2023, in PowerShell:
foreach ($G in ls W:\2023\*.jp2 ){
    gdal_retile -co WORLDFILE=YES -overlap 100 -ps 512 512 -csv "W:\\2023_tiles\\$($G.Basename).csv" -f JPEG -ot Byte -targetDir "W:\\2023_tiles" $G.fullname
}
For the 0.15 meter resolution photo's of 2023 this gave me 2626657 files.
The making of a weighted dataset¶
I will use the Flemish land use dataset 'Bodem Bedekkings Kaart' (BBK) from the Flemish Government:
https://www.vlaanderen.be/datavindplaats/catalogus/bodembedekkingskaart-bbk-1m-resolutie-opname-2018
The BBK is a segmentation map, derived from multispectral (RGB-NIR) aerial images, that shows the land cover in Flanders. Based on pixel classification supplemented with vectorial ground-truth data, the data is divided into classes such as low green, agriculture, forest, waterways and buildings. The target audience of these maps at 1 m and 5 m resolution is the general user who wants to consult a land cover map of Flanders as a basis for various analyses relating to land cover or land use.
This map has 14 broad categories, plus 0 for nodata (which usually lies outside Flanders):
- Buildings
- Roads
- Other constructed
- Railroads
- Water
- Other Natural
- Field - Agriculture
- Grass - Bushes (Nature)
- Trees (Nature)
- Grass Bushes (Agriculture)
- Grass Bushes (Road Edge)
- Trees (Road edge)
- Grass Bushes (Water Edge)
- Trees (Water Edge)

We sample, for the extent of every photo, the dominant category. We use the Zonal statistics tool in QGIS for this: https://docs.qgis.org/3.34/en/docs/user_manual/processing_algs/qgis/rasteranalysis.html#qgiszonalstatisticsfb
The result is saved to an arrow file.
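The "dominant category" statistic can be sketched in numpy, as an illustration of what the QGIS zonal-statistics majority computes per tile (the raster patch below is made up):

```python
import numpy as np

# hypothetical BBK raster patch (category codes 0-14) covering one tile footprint
patch = np.array([[7, 7, 9],
                  [7, 5, 7],
                  [9, 7, 1]])
counts = np.bincount(patch.ravel(), minlength=15)  # pixel count per category code
dominant = int(counts.argmax())                    # the majority category
print(dominant)  # 7: "Akker (Landbouw)" dominates this patch
```

Each tile then carries one dominant category, which is what the weighting below is based on.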
Below you see this file plotted for the training data: yellow for agricultural types, green for natural types, blue for water and reds for constructed types.

BBK_CATEGORIES = {
0: "Out of scope, onbekend",
1: "Gebouwen",
2: "Autowegen",
3: "Overig Afgedekt",
4: "Spoorwegen",
5: "Water",
6: "Overig Onafgedekt",
7: "Akker (Landbouw)",
8: "Gras Struiken (Groen)",
9: "Bomen (Groen)",
10: "Gras Struiken (Landbouw)",
11: "Gras Struiken (Wegrand)" ,
12: "Bomen (Wegrand)" ,
13: "Gras Struiken (Waterrand)",
14: "Bomen (Waterrand)"}
grouped_categories= {"UNKNOWN": [0], "BUILDING": [1], "ROADLIKE":[2,3,4], "GREEN": [6,8,10,11],
"AGRO": [7], "WOOD": [9,12], "WATER":[5,13,14] }
ds = pd.read_feather(r"W:\1989_tiles\index_landuse.arrow") # r"W:\2023_tiles\2023tiles_landuse.arrow") # r"W:\2015_tiles\2015tiles_landuse.arrow")
#group this to a few simpler categories
ds["CATEGORY"] = ds["BBK_CAT"].map({v: k for k,vv in grouped_categories.items() for v in vv})
ds["CATEGORY"] = ds["CATEGORY"].astype("category")
ds["BBK_CAT"] = ds["BBK_CAT"].astype("category")
ds["BBK_CAT"]= ds["BBK_CAT"].cat.rename_categories(BBK_CATEGORIES.values())
ds[["path", "BBK_CAT","CATEGORY"]].sample(5)
| | path | BBK_CAT | CATEGORY |
|---|---|---|---|
| 292613 | W:\2015_tiles\K127n\OGWRGB13_15VL_K127n_19_32.jpg | Gras Struiken (Landbouw) | GREEN |
| 717788 | W:\2015_tiles\K205n\OGWRGB13_15VL_K205n_14_68.jpg | Gras Struiken (Landbouw) | GREEN |
| 942831 | W:\2015_tiles\K243n\OGWRGB13_15VL_K243n_18_72.jpg | Bomen (Groen) | WOOD |
| 774678 | W:\2015_tiles\K214z\OGWRGB13_15VL_K214z_24_18.jpg | Akker (Landbouw) | AGRO |
| 509148 | W:\2015_tiles\K164n\OGWRGB13_15VL_K164n_11_39.jpg | Akker (Landbouw) | AGRO |
Also check whether the files exist on disk, and remove files that don't exist or are more than half black (= nodata).
ds = ds[ds["path"].apply(lambda x: os.path.exists(x) and np.median( imread(x) ) != 0 ) ]
ds.reset_index(inplace=True)
Calculate the percentage of occurrences of each CATEGORY.
ax = (100* ds.CATEGORY.value_counts() / ds.CATEGORY.count()).plot(
edgecolor='#fff', kind='bar',alpha=0.9, rot=0,
color =['#0b0', '#ffff0e','#eee', '#964B00', '#d62728', '#0ff', '#f0f'])
ax.set_title('The percentage of occurrences of CATEGORY')
ax.set_xlabel(None)
ax.set_ylabel('%')
Text(0, 0.5, '%')
We can use these inverse value counts as weights on the tiles: the less common a value, the more likely it will be picked in weighted sampling.
weights = 1/ds["CATEGORY"].value_counts()
weights
CATEGORY
GREEN       0.000002
AGRO        0.000002
WOOD        0.000005
ROADLIKE    0.000016
BUILDING    0.000021
WATER       0.000039
Name: count, dtype: float64
# set as WEIGHT; for readability and performance, convert to integer
ds['WEIGHT'] = ds['CATEGORY'].map( weights*10e6 ).astype('int32')
# you can use this field to sample this dataset in balanced manner,
# replace is False so a photo won't be picked twice.
ds.sample(5, weights='WEIGHT', replace=False)[['path',"BBK_CAT","CATEGORY","WEIGHT"]]
| | path | BBK_CAT | CATEGORY | WEIGHT |
|---|---|---|---|---|
| 179228 | W:\2015_tiles\K085z\OGWRGB13_15VL_K085z_25_01.jpg | Akker (Landbouw) | AGRO | 23 |
| 798295 | W:\2015_tiles\K217z\OGWRGB13_15VL_K217z_39_01.jpg | Akker (Landbouw) | AGRO | 23 |
| 884467 | W:\2015_tiles\K233n\OGWRGB13_15VL_K233n_39_61.jpg | Gras Struiken (Waterrand) | WATER | 387 |
| 1093354 | W:\2015_tiles\K267z\OGWRGB13_15VL_K267z_10_38.jpg | Akker (Landbouw) | AGRO | 23 |
| 86016 | W:\2015_tiles\K065z\OGWRGB13_15VL_K065z_45_22.jpg | Akker (Landbouw) | AGRO | 23 |
# since water is mostly featureless, lets lower its weight a bit more
ds.loc[ds.CATEGORY == 'WATER', 'WEIGHT'] = ds.loc[ds.CATEGORY == 'WATER', 'WEIGHT'] // 3
# as well for unknown
ds.loc[ds.CATEGORY == 'UNKNOWN', 'WEIGHT'] = ds.loc[ds.CATEGORY == 'UNKNOWN', 'WEIGHT'] // 3
ds.WEIGHT.unique()
array([ 23, 46, 20, 160, 210, 129])
ds.sample(10, weights='WEIGHT', replace=False)[['path',"BBK_CAT","CATEGORY","WEIGHT"]]
| | path | BBK_CAT | CATEGORY | WEIGHT |
|---|---|---|---|---|
| 369338 | W:\2015_tiles\K141n\OGWRGB13_15VL_K141n_43_05.jpg | Akker (Landbouw) | AGRO | 23 |
| 1274885 | W:\2015_tiles\K305z\OGWRGB13_15VL_K305z_40_04.jpg | Bomen (Groen) | WOOD | 46 |
| 783289 | W:\2015_tiles\K215z\OGWRGB13_15VL_K215z_38_49.jpg | Gras Struiken (Groen) | GREEN | 20 |
| 872012 | W:\2015_tiles\K231z\OGWRGB13_15VL_K231z_24_08.jpg | Water | WATER | 129 |
| 474967 | W:\2015_tiles\K157z\OGWRGB13_15VL_K157z_05_22.jpg | Gebouwen | BUILDING | 210 |
| 255195 | W:\2015_tiles\K118z\OGWRGB13_15VL_K118z_29_41.jpg | Akker (Landbouw) | AGRO | 23 |
| 1129828 | W:\2015_tiles\K282z\OGWRGB13_15VL_K282z_04_27.jpg | Overig Afgedekt | ROADLIKE | 160 |
| 1205696 | W:\2015_tiles\K294z\OGWRGB13_15VL_K294z_17_01.jpg | Gras Struiken (Landbouw) | GREEN | 20 |
| 982207 | W:\2015_tiles\K248n\OGWRGB13_15VL_K248n_43_58.jpg | Gras Struiken (Landbouw) | GREEN | 20 |
| 582338 | W:\2015_tiles\K175z\OGWRGB13_15VL_K175z_37_65.jpg | Water | WATER | 129 |
Save the results back to an arrow file. Then repeat this flow for 2015 and 2023.
ds[['path',"BBK_CAT","CATEGORY","WEIGHT"]].to_feather('.\\data\\tiles_1989_weighted.arrow', compression='lz4')
Merging Results¶
The resulting arrow files are saved and added to the project.
I had about 60,000 images from 1989 at a ground resolution of 1 m, 1.65 million from 2015 at a resolution of 25 cm, and 2.62 million from 2023 at a resolution of 15 cm. While the resolution of 2023 is higher than that of 2015, its quality is lower, so I also reduce its importance a bit.
The dataset from 1989 is derived from analog color images, while those from 2015 and 2023 were taken digitally.
df1989 = pd.read_feather("data\\tiles_1989_weighted.arrow")
df1989.sample(4, weights='WEIGHT')
| | path | BBK_CAT | CATEGORY | WEIGHT |
|---|---|---|---|---|
| 17861 | W:\1989_tiles\OKZRGB79_90VL_K16\OKZRGB79_90VL_... | Gras Struiken (Groen) | GREEN | 5 |
| 58613 | W:\1989_tiles\OKZRGB79_90VL_K39\OKZRGB79_90VL_... | Akker (Landbouw) | AGRO | 4 |
| 45809 | W:\1989_tiles\OKZRGB79_90VL_K29\OKZRGB79_90VL_... | Autowegen | ROADLIKE | 69 |
| 21418 | W:\1989_tiles\OKZRGB79_90VL_K17\OKZRGB79_90VL_... | Overig Afgedekt | ROADLIKE | 69 |
df2015 = pd.read_feather("data\\tiles_2015_weighted.arrow")
df2015.sample(4, weights='WEIGHT')
| | path | BBK_CAT | CATEGORY | WEIGHT |
|---|---|---|---|---|
| 819860 | W:\2015_tiles\K222z\OGWRGB13_15VL_K222z_27_38.jpg | Akker (Landbouw) | AGRO | 23 |
| 571965 | W:\2015_tiles\K174n\OGWRGB13_15VL_K174n_48_66.jpg | Spoorwegen | ROADLIKE | 160 |
| 1455586 | W:\2015_tiles\K336n\OGWRGB13_15VL_K336n_04_57.jpg | Gebouwen | BUILDING | 210 |
| 386341 | W:\2015_tiles\K143z\OGWRGB13_15VL_K143z_21_04.jpg | Akker (Landbouw) | AGRO | 23 |
df2023 = pd.read_feather("data\\tiles_2023_weighted.arrow")
df2023.sample(4, weights='WEIGHT')
| | index | path | BBK_CAT | CATEGORY | WEIGHT |
|---|---|---|---|---|---|
| 2038408 | 2281336 | W:\2023_tiles\K327n\OMWRGBMRVL_K327n_58_80.jpg | Akker (Landbouw) | AGRO | 16 |
| 1709589 | 1933173 | W:\2023_tiles\K293n\OMWRGBMRVL_K293n_14_01.jpg | Overig Afgedekt | ROADLIKE | 52 |
| 887226 | 1041295 | W:\2023_tiles\K187n\OMWRGBMRVL_K187n_57_14.jpg | Akker (Landbouw) | AGRO | 16 |
| 617578 | 760235 | W:\2023_tiles\K157n\OMWRGBMRVL_K157n_08_15.jpg | Water | WATER | 38 |
df2023.WEIGHT = df2023.WEIGHT // 2  # 2023 has the most tiles but lower quality
df1989.WEIGHT = df1989.WEIGHT * 30  # boost the much smaller 1989 set
df = pd.concat([df1989, df2015, df2023], ignore_index=True)
df[["path","BBK_CAT","CATEGORY","WEIGHT"]].sample(25, weights='WEIGHT', replace=False)
| | path | BBK_CAT | CATEGORY | WEIGHT |
|---|---|---|---|---|
| 2574101 | W:\2023_tiles\K231z\OMWRGBMRVL_K231z_27_21.jpg | Gebouwen | BUILDING | 44 |
| 29305 | W:\1989_tiles\OKZRGB79_90VL_K21\OKZRGB79_90VL_... | Akker (Landbouw) | AGRO | 1320 |
| 494807 | W:\2015_tiles\K173z\OGWRGB13_15VL_K173z_04_54.jpg | Overig Afgedekt | ROADLIKE | 160 |
| 291307 | W:\2015_tiles\K136z\OGWRGB13_15VL_K136z_43_32.jpg | Bomen (Groen) | WOOD | 46 |
| 1853191 | W:\2023_tiles\K146n\OMWRGBMRVL_K146n_59_30.jpg | Water | WATER | 19 |
| 2659380 | W:\2023_tiles\K238z\OMWRGBMRVL_K238z_20_61.jpg | Akker (Landbouw) | AGRO | 8 |
| 657942 | W:\2015_tiles\K212z\OGWRGB13_15VL_K212z_07_56.jpg | Gras Struiken (Landbouw) | GREEN | 20 |
| 810922 | W:\2015_tiles\K236z\OGWRGB13_15VL_K236z_48_78.jpg | Gras Struiken (Groen) | GREEN | 20 |
| 28133 | W:\1989_tiles\OKZRGB79_90VL_K21\OKZRGB79_90VL_... | Akker (Landbouw) | AGRO | 570 |
| 1264204 | W:\2015_tiles\K334n\OGWRGB13_15VL_K334n_32_44.jpg | Bomen (Groen) | WOOD | 46 |
| 448829 | W:\2015_tiles\K165n\OGWRGB13_15VL_K165n_34_68.jpg | Bomen (Groen) | WOOD | 46 |
| 1713048 | W:\2023_tiles\K131z\OMWRGBMRVL_K131z_51_71.jpg | Gebouwen | BUILDING | 44 |
| 3588204 | W:\2023_tiles\K413n\OMWRGBMRVL_K413n_09_33.jpg | Akker (Landbouw) | AGRO | 8 |
| 187947 | W:\2015_tiles\K096n\OGWRGB13_15VL_K096n_09_06.jpg | Bomen (Groen) | WOOD | 46 |
| 518381 | W:\2015_tiles\K176z\OGWRGB13_15VL_K176z_19_33.jpg | Gras Struiken (Landbouw) | GREEN | 20 |
| 2054616 | W:\2023_tiles\K166z\OMWRGBMRVL_K166z_48_59.jpg | Gebouwen | BUILDING | 44 |
| 2590652 | W:\2023_tiles\K233n\OMWRGBMRVL_K233n_08_39.jpg | Water | WATER | 19 |
| 2737742 | W:\2023_tiles\K247n\OMWRGBMRVL_K247n_06_15.jpg | Gras Struiken (Groen) | GREEN | 6 |
| 940628 | W:\2015_tiles\K258n\OGWRGB13_15VL_K258n_30_61.jpg | Bomen (Groen) | WOOD | 46 |
| 538364 | W:\2015_tiles\K181n\OGWRGB13_15VL_K181n_40_35.jpg | Bomen (Groen) | WOOD | 46 |
| 12086 | W:\1989_tiles\OKZRGB79_90VL_K13\OKZRGB79_90VL_... | Akker (Landbouw) | AGRO | 390 |
| 16443 | W:\1989_tiles\OKZRGB79_90VL_K15\OKZRGB79_90VL_... | Water | WATER | 270 |
| 2170028 | W:\2023_tiles\K178n\OMWRGBMRVL_K178n_36_37.jpg | Autowegen | ROADLIKE | 26 |
| 675418 | W:\2015_tiles\K214z\OGWRGB13_15VL_K214z_39_60.jpg | Gras Struiken (Groen) | GREEN | 20 |
| 3511 | W:\1989_tiles\OKZRGB79_90VL_K07\OKZRGB79_90VL_... | Bomen (Groen) | WOOD | 390 |
df.to_feather("data\\tiles_merged.arrow", compression='lz4')
DATASET¶
def makeWeightedDataFromArrow(arrow, train_size=4000, test_size=1000,
                              pathField='NAME', weightField='WEIGHT', replacement=False):
    ds = pd.read_feather(arrow)
    # note: train and test are sampled independently, so a tile can end up in both
    train_paths = ds.sample(train_size, weights=weightField, replace=replacement)[pathField]
    test_paths = ds.sample(test_size, weights=weightField, replace=replacement)[pathField]
    return list(train_paths), list(test_paths)
train_paths, val_paths = makeWeightedDataFromArrow(
r'.\data\tiles_2015_Weighted.arrow', train_size=160, test_size=40,
pathField='path', weightField='WEIGHT', replacement=False)
print(len(train_paths), len(val_paths))
160 40
Augmentation¶
from skimage.exposure import adjust_gamma, adjust_sigmoid
from skimage.util import random_noise
from skimage.filters import gaussian
from skimage.transform import resize
def grainify(img: np.ndarray):
    "Make it grainy like an old photo"
    c, rows, cols = img.shape
    val = np.random.uniform(0.036, 0.107)**2
    # Full resolution noise
    noise_1 = np.zeros((rows, cols))
    noise_1 = random_noise(noise_1, mode='gaussian', var=val, clip=False)
    # Half resolution noise, upscaled to the original image size
    noise_2 = np.zeros((rows//2, cols//2))
    noise_2 = random_noise(noise_2, mode='gaussian', var=(val*2)**2, clip=False)
    noise_2 = resize(noise_2, (rows, cols))
    noise = noise_1 + noise_2
    noise = np.stack([noise]*c, axis=0)
    noisy_img = img/255 + noise  # add the noise to the input image
    return np.round(255 * noisy_img).clip(0, 255).astype(np.uint8)
def aug(img: np.ndarray):
    "some data augmentation on img"
    img = adjust_gamma(img, gamma=np.random.uniform(low=0.5, high=1.5))  # change lighting
    img = grainify(img)                                                  # make it grainy like an old photo
    img = adjust_sigmoid(img, gain=np.random.uniform(1, 10))             # contrast
    img = gaussian(img, np.random.uniform(0, 1.5), channel_axis=0)       # blur
    return (img*255).astype('uint8')
# Load random image using GDAL
imagePath = str( Path(r"W:\1979_tiles\OKZRGB79_90VL_K23\OKZRGB79_90VL_K23_0_9.png") )
image = gdal.Open(imagePath).ReadAsArray(buf_xsize=512, buf_ysize=512) #1pix == 1 meter
# Apply the transformation to your image
aug_image = aug(image)
fig,axs= plt.subplots(1,3)
axs[0].imshow( image.transpose((1, 2, 0)) )
axs[0].set_title('Original')
axs[1].imshow( aug_image.transpose((1, 2, 0)) )
axs[1].set_title('Random augmentation')
axs[2].imshow( rgb2gray(aug_image , channel_axis=0 ) , cmap='Greys_r' )
axs[2].set_title('Grayscale')
fig.set_size_inches(14,42)
fig.tight_layout()
Creating an iterable dataset object from the source data that is readable as a tensor in pytorch¶
class ColorizationDataset(Dataset):
    def __init__(self, paths: Iterable[os.PathLike], imsize: int = 256, rootDir: str = '', resize: bool = True):
        super().__init__()
        self.size = imsize
        self.paths = paths
        self.root = rootDir
        self.resize = resize

    def __getitem__(self, idx: int):
        imagePath = str(Path(self.root) / self.paths[idx])
        img = gdal.Open(imagePath).ReadAsArray(buf_xsize=self.size, buf_ysize=self.size)
        img = aug(img)
        img_lab = rgb2lab(img, channel_axis=0)                # converting RGB to L*a*b
        img_lab = torch.tensor(img_lab, dtype=torch.float32)  # convert to Tensor
        L = (img_lab[0] / 50. - 1).unsqueeze(0)               # between -1 and 1
        ab = img_lab[1:3] / 110                               # between -1 and 1
        return {'L': L, 'ab': ab}

    def __len__(self):
        return len(self.paths)
pretrain_dl = DataLoader(ColorizationDataset(train_paths, imsize=256), batch_size=8)
val_dl = DataLoader(ColorizationDataset(val_paths, imsize=256), batch_size=4)
data = next(iter(val_dl))
Ls, abs_ = data['L'], data['ab']
print(Ls.shape, abs_.shape)
print(len(pretrain_dl), len(val_dl), len(train_paths), len(val_paths))
torch.Size([4, 1, 256, 256]) torch.Size([4, 2, 256, 256]) 20 10 160 40
Generator¶
The generator is the main model that will predict the ab values from a panchromatic image.
I use an existing model, "resnet18", that we can download with TIMM and convert to a Dynamic U-net with fast.ai.
from fastai.vision.learner import create_body
from fastai.vision.models.unet import DynamicUnet
import timm
def ResUnet(n_input=1, n_output=2, size=224, timm_model_name='resnet18'):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = timm.create_model(timm_model_name, pretrained=True)
    body = create_body(model, pretrained=True, n_in=n_input, cut=-2)
    net_G = DynamicUnet(body, n_output, (size, size)).to(device)
    return net_G
Pretrain Generator¶
You can pretrain the generator a little by running a few images through it, but without a discriminator this will not deliver good results.
class statsMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.count, self.avg, self.sum = [0.] * 3

    def update(self, val, count=1):
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count
def pretrain_generator(net_G, pretrain_dl, epochs, lrate=1e-3):
    opt = optim.Adam(net_G.parameters(), lr=lrate)
    criterion = nn.L1Loss()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Started pretraining at {datetime.datetime.now()}")
    for e in range(epochs):
        loss_meter = statsMeter()
        for data in pretrain_dl:
            L, ab = data['L'].to(device), data['ab'].to(device)
            preds = net_G(L)
            loss = criterion(preds, ab)
            opt.zero_grad()
            loss.backward()
            opt.step()
            loss_meter.update(loss.item(), L.size(0))
        print(f"Epoch {e + 1}/{epochs}")
        print(f"L1 Loss: {loss_meter.avg:.5f}")
    print(f"Finished pretraining at {datetime.datetime.now()}")
    return net_G
net_G = ResUnet(n_input=1, n_output=2, size=256)
net_G = pretrain_generator(net_G, pretrain_dl, epochs=5)
torch.save(net_G.state_dict(), "runs\\demo\\res18-unet_demo.pt")
Started pretraining at 2024-01-03 11:32:33.917362
Epoch 1/5
L1 Loss: 5204.15472
Epoch 2/5
L1 Loss: 0.02970
Epoch 3/5
L1 Loss: 0.00766
Epoch 4/5
L1 Loss: 0.00620
Epoch 5/5
L1 Loss: 0.00640
Finished pretraining at 2024-01-03 11:33:11.607399
Test Unet before GAN training¶
We see it produces images that look kind of sepia.
img_test = list( Path(r"W:\testdata\tiles_1950_gray").glob('*.png') )
randImg = lambda: str( img_test[ np.random.randint(0, len(img_test) ) ])
resunet = ResUnet()
state_dict = torch.load(Path(".\\runs\\demo\\res18-unet_demo.pt"), map_location=device)
resunet.load_state_dict(state_dict)
<All keys matched successfully>
from model.tools import lab_to_rgb
testSet = []
for i in range(4):
    ds = gdal.Open(randImg())
    # scale the single grey band to [-1, 1], like the L channel during training
    b1 = torch.Tensor((ds.GetRasterBand(1).ReadAsArray() / 128) - 1).unsqueeze(0)
    testSet.append(b1.unsqueeze(0))
testSet = torch.cat(testSet)
f, axs= plt.subplots(4,2)
axs[0][0].imshow(testSet[0][0], cmap='Greys_r')
axs[1][0].imshow(testSet[1][0], cmap='Greys_r')
axs[2][0].imshow(testSet[2][0], cmap='Greys_r')
axs[3][0].imshow(testSet[3][0], cmap='Greys_r')
with torch.inference_mode():
    w = resunet(testSet.to(device))
    colorized = lab_to_rgb(testSet, w.cpu())
axs[0][1].imshow(colorized[0])
axs[1][1].imshow(colorized[1])
axs[2][1].imshow(colorized[2])
axs[3][1].imshow(colorized[3])
f.set_size_inches(4,8)
f.tight_layout()
Patch Discriminator¶
This is the part of the GAN model that will act as the adversary, and it should become much better at distinguishing between real and colorized images than the plain L1 loss we used in the pretraining.
class PatchDiscriminator(nn.Module):
    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        model = [self.get_layers(input_c, num_filters, norm=False)]
        model += [self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1),
                                  s=1 if i == (n_down - 1) else 2)
                  for i in range(n_down)]
        model += [self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False)]
        self.model = nn.Sequential(*model)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        layers = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]
        if norm: layers += [nn.BatchNorm2d(nf)]
        if act: layers += [nn.LeakyReLU(0.2, True)]
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
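As a quick sanity check on the "patch" idea: with 4×4 kernels, padding 1 and strides 2, 2, 2, 1, 1 (the default n_down=3 gives five convolutions), each spatial size n maps to (n + 2·1 − 4)//s + 1, so a 256×256 input ends up as a grid of per-patch realness logits rather than a single score:

```python
def conv_out(n: int, k: int = 4, s: int = 2, p: int = 1) -> int:
    # output size of one Conv2d along a single spatial dimension
    return (n + 2 * p - k) // s + 1

n = 256
for stride in (2, 2, 2, 1, 1):  # strides of the five convolutions above
    n = conv_out(n, s=stride)
print(n)  # 30
```

So for a 256×256 Lab input the discriminator judges a 30×30 grid of overlapping patches, which is what makes it a "patch" discriminator.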
GAN Loss¶
Use the output of the discriminator to calculate a loss for the generator. The loss of the GAN's generator is calculated as a binary cross-entropy loss between the discriminator's output and the label data.
class GANLoss(nn.Module):
    def __init__(self, real_label=1., fake_label=0.):
        super().__init__()
        # float labels, as BCEWithLogitsLoss expects floating-point targets
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        self.loss = nn.BCEWithLogitsLoss()

    def get_labels(self, preds, target_is_real):
        if target_is_real:
            labels = self.real_label
        else:
            labels = self.fake_label
        return labels.expand_as(preds)

    def __call__(self, preds, target_is_real):
        labels = self.get_labels(preds, target_is_real)
        loss = self.loss(preds, labels)
        return loss
Main Model: bringing it all together¶
This is the final Generative Adversarial Network (GAN), consisting of both the U-net generator and the discriminator.
class MainModel(nn.Module):
def __init__(self, lr_G=2e-4, lr_D=2e-4,
beta1=0.5, beta2=0.999, lambda_L1=100.):
super().__init__()
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.lambda_L1 = lambda_L1
self.net_G = ResUnet().to(self.device) # Generator
self.net_D = self.init_weights( # Discriminator
PatchDiscriminator(input_c=3, n_down=3, num_filters=64)).to(self.device)
self.GANcriterion = GANLoss(torch.float32).to(self.device)
self.L1criterion = nn.L1Loss()
self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))
def init_weights(self, net, gain:float=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and 'Conv' in classname:
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias.data, 0.0)
elif 'BatchNorm2d' in classname:
nn.init.normal_(m.weight.data, 1., gain)
nn.init.constant_(m.bias.data, 0.)
net.apply(init_func)
return net
def set_requires_grad(self, model, requires_grad=True):
for p in model.parameters():
p.requires_grad = requires_grad
def setup_input(self, data):
self.L = data['L'].to(self.device)
self.ab = data['ab'].to(self.device)
def forward(self):
self.fake_color = self.net_G(self.L)
def backward_D(self):
fake_image = torch.cat([self.L, self.fake_color], dim=1)
fake_preds = self.net_D(fake_image.detach())
self.loss_D_fake = self.GANcriterion(fake_preds, False)
real_image = torch.cat([self.L, self.ab], dim=1)
real_preds = self.net_D(real_image)
self.loss_D_real = self.GANcriterion(real_preds, True)
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
self.loss_D.backward()
def backward_G(self):
fake_image = torch.cat([self.L, self.fake_color], dim=1)
fake_preds = self.net_D(fake_image)
self.loss_G_GAN = self.GANcriterion(fake_preds, True)
self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
self.loss_G = self.loss_G_GAN + self.loss_G_L1
self.loss_G.backward()
def optimize(self):
self.forward()
self.net_D.train()
self.set_requires_grad(self.net_D, True)
self.opt_D.zero_grad()
self.backward_D()
self.opt_D.step()
self.net_G.train()
self.set_requires_grad(self.net_D, False)
self.opt_G.zero_grad()
self.backward_G()
self.opt_G.step()
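Note the `detach()` in `backward_D`: it blocks the gradients of the discriminator loss from flowing back into the generator. A minimal sketch (with scalar tensors standing in for the two networks) shows the effect:

```python
import torch

# Stand-ins for one generator parameter and one discriminator parameter.
g_param = torch.ones(1, requires_grad=True)
d_param = torch.ones(1, requires_grad=True)

fake = g_param * 3.0                             # "generator output"
d_loss = (d_param * fake.detach()).pow(2).sum()  # discriminator loss on detached fake
d_loss.backward()

# The discriminator gets a gradient, the generator does not.
print(g_param.grad, d_param.grad)
```

In `backward_G` the fake image is *not* detached, so there the GAN loss does update the generator; `set_requires_grad(self.net_D, False)` keeps the discriminator frozen during that step.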
Train¶
Utils¶
Store stats, convert CIELAB data to RGB, and visualize results.
def create_loss_meters():
loss_D_fake = statsMeter()
loss_D_real = statsMeter()
loss_D = statsMeter()
loss_G_GAN = statsMeter()
loss_G_L1 = statsMeter()
loss_G = statsMeter()
return {'loss_D_fake': loss_D_fake,
'loss_D_real': loss_D_real,
'loss_D': loss_D,
'loss_G_GAN': loss_G_GAN,
'loss_G_L1': loss_G_L1,
'loss_G': loss_G}
def update_losses(model, loss_meter_dict, count):
for loss_name, loss_meter in loss_meter_dict.items():
loss = getattr(model, loss_name)
loss_meter.update(loss.item(), count=count)
def lab_to_rgb(L, ab):
"""
Takes a batch of images
"""
L = (L + 1.) * 50.
ab = ab * 110.
Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
rgb_imgs = []
for img in Lab:
img_rgb = lab2rgb(img)
rgb_imgs.append(img_rgb)
return np.stack(rgb_imgs, axis=0)
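The rescaling in `lab_to_rgb` assumes the network works in a normalised [-1, 1] range; a quick check of the mapping back to CIELAB ranges (plain numbers, no skimage needed):

```python
# lab_to_rgb maps normalised network outputs back to CIELAB ranges:
# L in [-1, 1] -> [0, 100] (lightness), ab in [-1, 1] -> [-110, 110].
L_norm = [-1.0, 0.0, 1.0]
ab_norm = [-1.0, 0.0, 1.0]
L_lab = [(v + 1.) * 50. for v in L_norm]
ab_lab = [v * 110. for v in ab_norm]
print(L_lab, ab_lab)  # [0.0, 50.0, 100.0] [-110.0, 0.0, 110.0]
```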
def visualize(model, data):
model.net_G.eval()
with torch.no_grad():
model.setup_input(data)
model.forward()
model.net_G.train()
fake_color = model.fake_color.detach()
real_color = model.ab
L = model.L
fake_imgs = lab_to_rgb(L, fake_color)
real_imgs = lab_to_rgb(L, real_color)
for i in range(4):
ax = plt.subplot(3, 5, i + 1)
ax.imshow(L[i][0].cpu(), cmap='gray')
ax.axis("off")
ax = plt.subplot(3, 5, i + 1 + 5)
ax.imshow(fake_imgs[i])
ax.axis("off")
ax = plt.subplot(3, 5, i + 1 + 10)
ax.imshow(real_imgs[i])
ax.axis("off")
plt.show()
def log_results(loss_meter_dict):
for loss_name, loss_meter in loss_meter_dict.items():
print(f"{loss_name}: {loss_meter.avg:.4f}")
Training¶
def train_model(model, train_dl, epochs):
model.train()
    # a validation batch is fetched each epoch to visualize the model output
for e in range(epochs):
        loss_meter_dict = create_loss_meters() # returns a dictionary of statsMeter objects to log the losses
for data in train_dl:
model.setup_input(data)
model.optimize()
batchsize = data['L'].size(0)
update_losses(model, loss_meter_dict, count=batchsize) # function updating the log objects
print(f"\nEpoch {e+1}/{epochs}")
log_results(loss_meter_dict) # function to print out the losses
test_data = next(iter(val_dl))
visualize(model, test_data) # function displaying the model's outputs
model = MainModel()
train_model(model, pretrain_dl, 5)
model.eval()
torch.save(model.state_dict(), 'runs\\models\\testRun.pth')
Epoch 1/5 loss_D_fake: 0.4419 loss_D_real: 0.4319 loss_D: 0.4369 loss_G_GAN: 1.3350 loss_G_L1: 4.6299 loss_G: 5.9649
Epoch 2/5 loss_D_fake: 0.4540 loss_D_real: 0.4412 loss_D: 0.4476 loss_G_GAN: 1.3924 loss_G_L1: 4.9903 loss_G: 6.3826
Epoch 3/5 loss_D_fake: 0.6094 loss_D_real: 0.5500 loss_D: 0.5797 loss_G_GAN: 1.2045 loss_G_L1: 4.9443 loss_G: 6.1488
Epoch 4/5 loss_D_fake: 0.5511 loss_D_real: 0.5669 loss_D: 0.5590 loss_G_GAN: 1.1266 loss_G_L1: 4.6503 loss_G: 5.7769
Epoch 5/5 loss_D_fake: 0.5397 loss_D_real: 0.5357 loss_D: 0.5377 loss_G_GAN: 1.1587 loss_G_L1: 4.5650 loss_G: 5.7236
Training in production:¶
In the example above I used ResNet18; in reality I used the much larger ResNet34 backbone.
To reduce memory use I trained with half-precision floating-point tensors, and I also used Huggingface accelerate to speed up training. This also makes it possible to train on multiple GPUs.
I trained for about 12 hours on a single NVIDIA RTX 3080 GPU.
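Half-precision can be illustrated with plain PyTorch's autocast (a CPU-safe sketch of the same idea; the actual training used Huggingface accelerate, which wraps this for fp16 on GPU):

```python
import torch

# Autocast runs supported ops (here a linear layer) in a lower-precision
# dtype, roughly halving activation memory. CPU autocast uses bfloat16.
model = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)
print(out.dtype)
```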
Inference¶
I use the results of the training to colorize genuinely panchromatic images: not colour photos, and not colour images converted to greyscale.
The proof of the pudding is in the eating.
import torch, numpy as np, matplotlib.pyplot as plt
from osgeo import gdal
gdal.UseExceptions()
from model.unet import ResUnet
from model.tools import lab_to_rgb
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
W= ".\\runs\\models\\run32\\color_run32_resnet34_512_net_G40.pth"
model = ResUnet(size=512, timm_model_name='resnet34')
model.to(device)
model.load_state_dict( torch.load(W, map_location=device ) )
<All keys matched successfully>
We use a JPEG-2000 image from 1970 obtained from the Flemish Government and a GeoTIFF from 1948 obtained from the NGI.
Both images have a ground resolution of about 1 metre.
We let numpy pick a random location of 512 by 512 pixels.
imsize = 512
# two large scale black&white orthophoto mosaics of southern Antwerp:
ds0 = gdal.Open( "W:\\1970\\OKZPAN71VL_K15.jp2" ) # 1970 1m resolution
ds1 = gdal.Open( "W:\\1948-1968\\1948antwZuid.tif" ) # 1948 1m resolution
# pick a random place on the photo's of 512x512 pixels
ds0_img = ds0.GetRasterBand(1).ReadAsArray(
xoff= np.random.randint(ds0.RasterXSize - imsize),
yoff= np.random.randint(ds0.RasterYSize - imsize),
win_xsize=imsize, win_ysize=imsize)
ds1_img = ds1.GetRasterBand(1).ReadAsArray(
xoff= np.random.randint(ds1.RasterXSize - imsize),
yoff= np.random.randint(ds1.RasterYSize - imsize),
win_xsize=imsize, win_ysize=imsize)
g0_img = torch.Tensor( ( ds0_img /128) -1 ).unsqueeze(0)
g1_img = torch.Tensor( ( ds1_img /128) -1 ).unsqueeze(0)
with torch.inference_mode():
pred0 = model(g0_img.unsqueeze(0).to(device) )
pred1 = model(g1_img.unsqueeze(0).to(device) )
colorized0 = lab_to_rgb(g0_img.unsqueeze(0), pred0.cpu())[0]
colorized1 = lab_to_rgb(g1_img.unsqueeze(0), pred1.cpu())[0]
f , axs = plt.subplots(2,2)
axs[0,0].imshow(g0_img[0], cmap='Greys_r')
axs[0,0].set_title('input')
axs[0,1].imshow(colorized0)
axs[0,1].set_title("colorized")
axs[1,0].imshow(g1_img[0], cmap='Greys_r')
axs[1,1].imshow(colorized1)
f.set_size_inches(15,15)
f.tight_layout()